import os
import yaml
import numpy as np
import argparse
from scipy.linalg import expm


def compute_U_from_A(A, gauge_group='U1'):
    """Compute the link variables U_mu = exp(i * A_mu).

    For ``gauge_group == 'U1'`` (compared case-insensitively), ``A`` is taken
    as an array of phases. The exponential is applied elementwise,
    ``U = exp(1j * A)``, so ``U`` has the same shape as ``A``.

    For any other gauge group, ``A`` is treated as a stack of N x N matrices
    with shape ``(..., N, N)``. Examples are ``(num_links, N, N)``, a lattice
    layout with several leading axes, or a single ``(N, N)`` matrix. The matrix
    exponential ``expm(1j * A_mat)`` is applied to every trailing N x N matrix.
    ``U`` is unitary when each ``A_mat`` is Hermitian.

    Args:
        A: Gauge-field array (scalar phases for U(1), matrices otherwise).
        gauge_group: Gauge group label; only ``'U1'`` selects the scalar path.

    Returns:
        Complex ndarray with the same shape as ``A``.

    Raises:
        ValueError: For a non-U1 group, if ``A`` has fewer than two dimensions
            or its last two axes are not equal (not square matrices).
    """
    if gauge_group.upper() == 'U1':
        return np.exp(1j * A)

    A = np.asarray(A)
    if A.ndim < 2 or A.shape[-1] != A.shape[-2]:
        raise ValueError(
            f'Expected A with shape (..., N, N) for gauge group {gauge_group}, '
            f'got shape {A.shape}'
        )

    n = A.shape[-1]
    # Collapse all leading axes into one so that each iteration sees exactly
    # one N x N matrix. Without this, iterating a bare (N, N) input would hand
    # 1-D rows to expm. int(np.prod(())) == 1 covers the single-matrix case and
    # avoids the ambiguous reshape(-1, 0, 0) on empty input.
    count = int(np.prod(A.shape[:-2], dtype=np.int64))
    flat = A.reshape(count, n, n)

    U = np.empty((count, n, n), dtype=complex)
    for idx, A_mat in enumerate(flat):
        U[idx] = expm(1j * A_mat)
    return U.reshape(A.shape)


def main(config_path: str):
    """Compute and save the link variables U_mu for every configured gauge group.

    Reads the YAML config at ``config_path`` and resolves ``data_dir`` (default
    ``'data'``). A relative ``data_dir`` is resolved against the config file's
    directory, not the current working directory. Then, for each entry in
    ``gauge_groups`` (default ``['U1']``), the function loads
    ``Amu_<G>.npy`` from ``data_dir``, converts it with ``compute_U_from_A``,
    and saves the result as ``Umu_<G>.npy`` in the same directory.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        ValueError: If the config's top level is not a mapping.
        FileNotFoundError: If an expected ``Amu_<G>.npy`` file is missing.
    """
    # Load configuration
    with open(config_path, encoding='utf-8') as f:
        # safe_load returns None for an empty file; treat that as "all defaults".
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f'Config file {config_path} must contain a YAML mapping at the top level'
        )

    # Resolve data_dir relative to the directory of the config file
    base_dir = os.path.dirname(os.path.abspath(config_path))
    data_dir_cfg = cfg.get('data_dir', 'data')
    if os.path.isabs(data_dir_cfg):
        data_dir = data_dir_cfg
    else:
        data_dir = os.path.join(base_dir, data_dir_cfg)

    gauge_groups = cfg.get('gauge_groups', ['U1'])
    # A bare scalar (``gauge_groups: U1``) would otherwise be iterated
    # character by character, producing lookups for Amu_U.npy, Amu_1.npy, ...
    if isinstance(gauge_groups, str):
        gauge_groups = [gauge_groups]

    # For each gauge group, load its A_mu array and compute the link
    # variables U_mu. The arrays are expected to be saved by
    # ``compute_Amu.py`` as ``Amu_<G>.npy`` within data_dir.
    for G in gauge_groups:
        Amu_name = f'Amu_{G}.npy'
        Amu_path = os.path.join(data_dir, Amu_name)
        if not os.path.exists(Amu_path):
            raise FileNotFoundError(f'Expected A_mu file {Amu_name} for gauge group {G} was not found in {data_dir}')
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code. Only load files you trust. It is
        # kept because the producer may save object arrays; confirm that and
        # drop it if plain numeric arrays suffice.
        Amu = np.load(Amu_path, allow_pickle=True)
        Umu = compute_U_from_A(Amu, gauge_group=G)
        out_name = f'Umu_{G}.npy'
        out_path = os.path.join(data_dir, out_name)
        np.save(out_path, Umu)
        print(f'Computed U_mu for {G}, saved to {out_path}')


if __name__ == '__main__':
    # Command-line entry point: a single optional --config flag selects the
    # YAML file, which is then handed straight to main().
    cli = argparse.ArgumentParser(
        description='Compute U_mu from A_mu for specified gauge groups',
    )
    cli.add_argument(
        '--config',
        default='config.yaml',
        help='Path to config file',
    )
    options = cli.parse_args()
    main(options.config)